Reviewed by Perseus14 (Apr 19 and Apr 20, 2026) and entrpn (Apr 20, 2026).
This PR adds features that improve performance for the LTX-2 model, along with a fix for the broken LTX-2 Upsampler on main.
Performance enhancement features:
All experiments were performed on v7x-8.
- Sharding fix in `NNXSimpleFeedForward`: Data is sharded along the sequence dimension, so each device holds a subset of tokens but the full feature channels. Because the input activations had replicated features while the weights expected features sharded on the same physical axis (`context`), XLA was forced to insert an all-gather on the sequence dimension to resolve the layout conflict, wasting significant step time. Our fix removes this all-gather.
- QKV projection sharding fix (Ironwood-specific): Profiling showed the input being all-gathered along the sequence dimension at the QKV projection step. Because the weights were sharded on the dimension being summed over (features), a single device could not complete the matrix multiplication using only its local shard of the data, so XLA automatically inserted an all-gather to replicate the sequence dimension across all devices before the multiplication. We changed the weight sharding in `attention_ltx2.py` to remove sharding on the input feature dimension.
- Batching in the text encoder: With Classifier-Free Guidance (CFG) enabled, the text encoder previously ran twice, once each for the positive and negative prompts. We now concatenate the two prompts and run a single text-encoder pass.
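The layout conflict behind the sharding fixes can be reproduced in a small sketch. This is illustrative JAX, not MaxDiffusion code; the mesh axis name `context` matches the description above, and all shapes and variable names are hypothetical:

```python
# Sketch: why sharding a weight on the contracting (input-feature)
# dimension forces XLA to all-gather sequence-sharded activations.
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Hypothetical 1-D mesh over the "context" axis used for sequence sharding.
mesh = Mesh(np.array(jax.devices()), axis_names=("context",))

seq, d_in, d_out = 8, 16, 32
x = jax.device_put(
    jnp.ones((seq, d_in)),
    NamedSharding(mesh, P("context", None)),  # tokens sharded, features replicated
)

# Before the fix: weight sharded on the input-feature (contracting) dim on
# the same axis; XLA resolves the layout conflict with an all-gather of x.
w_bad = jax.device_put(jnp.ones((d_in, d_out)), NamedSharding(mesh, P("context", None)))

# After the fix: contracting dim left unsharded, so each device can
# contract its local token shard without any all-gather.
w_good = jax.device_put(jnp.ones((d_in, d_out)), NamedSharding(mesh, P(None, None)))

y = jnp.einsum("sd,df->sf", x, w_good)  # output stays sharded along the sequence
```

With the contracting dimension unsharded, each device multiplies its local token shard directly and the output keeps the sequence-sharded layout.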
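The text-encoder batching change follows a standard CFG pattern and can be sketched as below; `encode` is a hypothetical stand-in for the real LTX-2 text encoder:

```python
# Sketch: with CFG enabled, run one batched text-encoder pass over the
# concatenated positive and negative prompts instead of two passes.
import jax.numpy as jnp

def encode(token_ids):
    # Stand-in for the real text encoder; any per-sequence function
    # suffices to demonstrate the batching pattern.
    return token_ids[..., None] * jnp.ones((1, 1, 4))

pos = jnp.ones((1, 7), dtype=jnp.int32)   # positive-prompt tokens
neg = jnp.zeros((1, 7), dtype=jnp.int32)  # negative-prompt tokens

# Before: two encoder passes.
emb_two = (encode(pos), encode(neg))

# After: a single pass over the concatenated batch, then split.
both = encode(jnp.concatenate([pos, neg], axis=0))
emb_pos, emb_neg = jnp.split(both, 2, axis=0)
```

The single batched pass produces the same embeddings while amortizing one encoder invocation over both prompts.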
- JITting the diffusion loop: The previous implementation used a Python `for` loop to iterate over diffusion timesteps. This created a "Python dispatch wall": the TPU sat idle between consecutive forward passes while waiting for the host CPU to dispatch the next step. We refactored the entire denoising loop to use `nnx.scan`.
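The loop fusion can be sketched with `jax.lax.scan` (the PR itself uses flax's `nnx.scan`, which wraps the same mechanism for NNX modules); `denoise_step` here is a hypothetical stand-in for the real diffusion-transformer step:

```python
# Sketch: fuse the per-timestep denoising loop into one compiled scan so
# the host never dispatches individual steps inside the loop.
import jax
import jax.numpy as jnp

def denoise_step(latents, timestep):
    # Stand-in update; the real step runs the diffusion transformer
    # forward pass and the scheduler update.
    return latents * (1.0 - 0.01 * timestep), None

timesteps = jnp.linspace(1.0, 0.0, num=8)  # hypothetical schedule
latents0 = jnp.ones((2, 16))               # hypothetical latent shape

# One traced loop: all 8 steps compile into a single XLA computation,
# eliminating per-step Python dispatch gaps.
final_latents, _ = jax.lax.scan(denoise_step, latents0, timesteps)
```

Because the whole loop is a single traced computation, the TPU runs step after step without waiting on the host between iterations.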
LTX2 Upsampler fix:
Results
v7x-8
We also tested the WAN I2V pipelines to confirm no regressions were introduced there; no quality regression or increased latency was observed.